{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.012019,
     "end_time": "2021-02-15T11:01:41.761156",
     "exception": false,
     "start_time": "2021-02-15T11:01:41.749137",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "\n",
    "\n",
    "This notebook contains an example for teaching.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "_execution_state": "idle",
    "_uuid": "051d70d956493feee0c6d64651c6a088724dca2a",
    "papermill": {
     "duration": 0.010774,
     "end_time": "2021-02-15T11:01:41.782833",
     "exception": false,
     "start_time": "2021-02-15T11:01:41.772059",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Penalized Linear Regressions: A Simulation Experiment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.010616,
     "end_time": "2021-02-15T11:01:41.804126",
     "exception": false,
     "start_time": "2021-02-15T11:01:41.793510",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Data Generating Process: Approximately Sparse"
   ]
  },
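  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the design simulated in the next cell can be written out as follows. This is a sketch read off the code; the symbols $\\epsilon$ and $j$ are our notation, not the notebook's.\n",
    "\n",
    "$$\n",
    "Y = g(X) + \\epsilon, \\quad \\epsilon \\sim N(0,1), \\qquad g(X) = \\exp(4Z) + W'\\beta, \\qquad \\beta_j = 1/j^2, \\quad j = 1, \\dots, p,\n",
    "$$\n",
    "\n",
    "with $n = 100$, $p = 400$, and $Z$ and the entries of $W$ drawn independently from the uniform distribution on $[-1/2, 1/2]$. The regressors handed to the estimators are $X = (Z, Z^2, Z^3, W)$, so the nonlinear term $\\exp(4Z)$ is only approximated by a cubic polynomial in $Z$.\n"
   ]
  },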
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.603407,
     "end_time": "2021-02-15T11:01:42.418301",
     "exception": false,
     "start_time": "2021-02-15T11:01:41.814894",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "set.seed(1)\n",
    "\n",
    "n = 100;\n",
    "p = 400;\n",
    "\n",
    "Z= runif(n)-1/2;\n",
    "W = matrix(runif(n*p)-1/2, n, p);\n",
    "\n",
    "\n",
    "\n",
    "beta = 1/seq(1:p)^2;   # approximately sparse beta\n",
    "#beta = rnorm(p)*.2    # dense beta\n",
    "gX = exp(4*Z)+ W%*%beta;  # leading term nonlinear\n",
    "X = cbind(Z, Z^2, Z^3, W );  # polynomials in Zs will be approximating exp(4*Z)\n",
    "\n",
    "\n",
    "Y = gX + rnorm(n);    #generate Y\n",
    "\n",
    "\n",
    "plot(gX,Y, xlab=\"g(X)\", ylab=\"Y\")    #plot V vs g(X)\n",
    "\n",
    "print( c(\"theoretical R2:\", var(gX)/var(Y)))\n",
    "\n",
    "var(gX)/var(Y); #theoretical R-square in the simulation example\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.013571,
     "end_time": "2021-02-15T11:01:42.446308",
     "exception": false,
     "start_time": "2021-02-15T11:01:42.432737",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "We use package Glmnet to carry out predictions using cross-validated lasso, ridge, and elastic net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 2.898022,
     "end_time": "2021-02-15T11:01:45.358083",
     "exception": false,
     "start_time": "2021-02-15T11:01:42.460061",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "library(glmnet)\n",
    "fit.lasso.cv   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=1)  # family gaussian means that we'll be using square loss\n",
    "fit.ridge   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=0)     # family gaussian means that we'll be using square loss\n",
    "fit.elnet   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=.5)    # family gaussian means that we'll be using square loss\n",
    "\n",
    "yhat.lasso.cv    <- predict(fit.lasso.cv, newx = X)            # predictions\n",
    "yhat.ridge   <- predict(fit.ridge, newx = X)\n",
    "yhat.elnet   <- predict(fit.elnet, newx = X)\n",
    "\n",
    "MSE.lasso.cv <- summary(lm((gX-yhat.lasso.cv)^2~1))$coef[1:2]  # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.ridge <- summary(lm((gX-yhat.ridge)^2~1))$coef[1:2]        # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.elnet <- summary(lm((gX-yhat.elnet)^2~1))$coef[1:2]        # report MSE and standard error for MSE for approximating g(X)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.01429,
     "end_time": "2021-02-15T11:01:45.388902",
     "exception": false,
     "start_time": "2021-02-15T11:01:45.374612",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Here we compute the lasso and ols post lasso using plug-in choices for penalty levels, using package hdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 11.419928,
     "end_time": "2021-02-15T11:01:56.823076",
     "exception": false,
     "start_time": "2021-02-15T11:01:45.403148",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(hdm) \n",
    "fit.rlasso  <- rlasso(Y~X,  post=FALSE)      # lasso with plug-in penalty level\n",
    "fit.rlasso.post <- rlasso(Y~X,  post=TRUE)    # post-lasso with plug-in penalty level\n",
    "\n",
    "yhat.rlasso   <- predict(fit.rlasso)            #predict g(X) for values of X\n",
    "yhat.rlasso.post   <- predict(fit.rlasso.post)  #predict g(X) for values of X\n",
    "\n",
    "MSE.lasso <- summary(lm((gX-yhat.rlasso)^2~1))$coef[1:2]       # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.lasso.post <- summary(lm((gX-yhat.rlasso.post)^2~1))$coef[1:2]  # report MSE and standard error for MSE for approximating g(X)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.02899,
     "end_time": "2021-02-15T11:01:56.880825",
     "exception": false,
     "start_time": "2021-02-15T11:01:56.851835",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Next we code up lava, which alternates the fitting of lasso and ridge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 54.672772,
     "end_time": "2021-02-15T11:02:51.582461",
     "exception": false,
     "start_time": "2021-02-15T11:01:56.909689",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(glmnet)\n",
    "\n",
    "lava.predict<- function(X,Y, iter=5){\n",
    "    \n",
    "g1 = predict(rlasso(X, Y, post=F))  #lasso step fits \"sparse part\"\n",
    "m1 =  predict(glmnet(X, as.vector(Y-g1), family=\"gaussian\", alpha=0, lambda =20),newx=X ) #ridge step fits the \"dense\" part\n",
    "\n",
    "    \n",
    "i=1\n",
    "while(i<= iter) {\n",
    "g1 = predict(rlasso(X, Y, post=F))   #lasso step fits \"sparse part\"\n",
    "m1 = predict(glmnet(X, as.vector(Y-g1), family=\"gaussian\",  alpha=0, lambda =20),newx=X );  #ridge step fits the \"dense\" part\n",
    "i = i+1 }\n",
    "\n",
    "return(g1+m1);\n",
    "    }\n",
    "\n",
    "\n",
    "yhat.lava = lava.predict(X,Y)\n",
    "MSE.lava <- summary(lm((gX-yhat.lava)^2~1))$coef[1:2]       # report MSE and standard error for MSE for approximating g(X)\n",
    "\n",
    "    \n",
    "MSE.lava"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.113713,
     "end_time": "2021-02-15T11:02:51.724821",
     "exception": false,
     "start_time": "2021-02-15T11:02:51.611108",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(xtable)\n",
    "table<- matrix(0, 6, 2)\n",
    "table[1,1:2]   <- MSE.lasso.cv\n",
    "table[2,1:2]   <- MSE.ridge\n",
    "table[3,1:2]   <- MSE.elnet\n",
    "table[4,1:2]   <- MSE.lasso\n",
    "table[5,1:2]   <- MSE.lasso.post\n",
    "table[6,1:2]   <- MSE.lava\n",
    "\n",
    "colnames(table)<- c(\"MSA\", \"S.E. for MSA\")\n",
    "rownames(table)<- c(\"Cross-Validated Lasso\", \"Cross-Validated ridge\",\"Cross-Validated elnet\",\n",
    "                    \"Lasso\",\"Post-Lasso\",\"Lava\")\n",
    "tab <- xtable(table, digits =3)\n",
    "print(tab,type=\"latex\") # set type=\"latex\" for printing table in LaTeX\n",
    "tab\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.15987,
     "end_time": "2021-02-15T11:02:51.903624",
     "exception": false,
     "start_time": "2021-02-15T11:02:51.743754",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "plot(gX, gX, pch=19, cex=1, ylab=\"predicted value\", xlab=\"true g(X)\")\n",
    "\n",
    "points(gX, yhat.rlasso, col=2, pch=18, cex = 1.5 )\n",
    "points(gX,  yhat.rlasso.post, col=3, pch=17,  cex = 1.2  )\n",
    "points( gX, yhat.lasso.cv,col=4, pch=19,  cex = 1.2 )\n",
    "\n",
    "\n",
    "legend(\"bottomright\", \n",
    "  legend = c(\"rLasso\", \"Post-rLasso\", \"CV Lasso\"), \n",
    "  col = c(2,3,4), \n",
    "  pch = c(18,17, 19), \n",
    "  bty = \"n\", \n",
    "  pt.cex = 1.3, \n",
    "  cex = 1.2, \n",
    "  text.col = \"black\", \n",
    "  horiz = F , \n",
    "  inset = c(0.1, 0.1))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "papermill": {
     "duration": 0.018842,
     "end_time": "2021-02-15T11:02:51.941852",
     "exception": false,
     "start_time": "2021-02-15T11:02:51.923010",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Data Generating Process: Approximately Sparse + Small Dense Part"
   ]
  },
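  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This second design keeps the sample size, the regressors, and the noise of the first design and changes only the coefficient vector, which is now dense. Read off the code in the next cell (again in our notation, not the notebook's):\n",
    "\n",
    "$$\n",
    "g(X) = \\exp(4Z) + W'\\beta, \\qquad \\beta_j \\sim N(0, 0.2^2) \\ \\text{independently}, \\quad j = 1, \\dots, p.\n",
    "$$\n",
    "\n",
    "The term $\\exp(4Z)$ is still captured by a few polynomial terms in $Z$, while $W'\\beta$ spreads small coefficients over all $p = 400$ columns of $W$.\n"
   ]
  },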
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.207598,
     "end_time": "2021-02-15T11:02:52.168536",
     "exception": false,
     "start_time": "2021-02-15T11:02:51.960938",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "set.seed(1)\n",
    "\n",
    "n = 100;\n",
    "p = 400;\n",
    "\n",
    "Z= runif(n)-1/2;\n",
    "W = matrix(runif(n*p)-1/2, n, p);\n",
    "\n",
    "\n",
    "beta = rnorm(p)*.2    # dense beta\n",
    "gX = exp(4*Z)+ W%*%beta;  # leading term nonlinear\n",
    "X = cbind(Z, Z^2, Z^3, W );  # polynomials in Zs will be approximating exp(4*Z)\n",
    "\n",
    "\n",
    "Y = gX + rnorm(n);    #generate Y\n",
    "\n",
    "\n",
    "plot(gX,Y, xlab=\"g(X)\", ylab=\"Y\")    #plot V vs g(X)\n",
    "\n",
    "print( c(\"theoretical R2:\", var(gX)/var(Y)))\n",
    "\n",
    "var(gX)/var(Y); #theoretical R-square in the simulation example\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 1.432822,
     "end_time": "2021-02-15T11:02:53.626802",
     "exception": false,
     "start_time": "2021-02-15T11:02:52.193980",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "library(glmnet)\n",
    "fit.lasso.cv   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=1)  # family gaussian means that we'll be using square loss\n",
    "fit.ridge   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=0)     # family gaussian means that we'll be using square loss\n",
    "fit.elnet   <- cv.glmnet(X, Y, family=\"gaussian\", alpha=.5)    # family gaussian means that we'll be using square loss\n",
    "\n",
    "yhat.lasso.cv    <- predict(fit.lasso.cv, newx = X)            # predictions\n",
    "yhat.ridge   <- predict(fit.ridge, newx = X)\n",
    "yhat.elnet   <- predict(fit.elnet, newx = X)\n",
    "\n",
    "MSE.lasso.cv <- summary(lm((gX-yhat.lasso.cv)^2~1))$coef[1:2]  # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.ridge <- summary(lm((gX-yhat.ridge)^2~1))$coef[1:2]        # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.elnet <- summary(lm((gX-yhat.elnet)^2~1))$coef[1:2]        # report MSE and standard error for MSE for approximating g(X)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 13.756606,
     "end_time": "2021-02-15T11:03:07.405363",
     "exception": false,
     "start_time": "2021-02-15T11:02:53.648757",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(hdm) \n",
    "fit.rlasso  <- rlasso(Y~X,  post=FALSE)      # lasso with plug-in penalty level\n",
    "fit.rlasso.post <- rlasso(Y~X,  post=TRUE)    # post-lasso with plug-in penalty level\n",
    "\n",
    "yhat.rlasso   <- predict(fit.rlasso)            #predict g(X) for values of X\n",
    "yhat.rlasso.post   <- predict(fit.rlasso.post)  #predict g(X) for values of X\n",
    "\n",
    "MSE.lasso <- summary(lm((gX-yhat.rlasso)^2~1))$coef[1:2]       # report MSE and standard error for MSE for approximating g(X)\n",
    "MSE.lasso.post <- summary(lm((gX-yhat.rlasso.post)^2~1))$coef[1:2]  # report MSE and standard error for MSE for approximating g(X)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 67.457885,
     "end_time": "2021-02-15T11:04:14.919093",
     "exception": false,
     "start_time": "2021-02-15T11:03:07.461208",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(glmnet)\n",
    "\n",
    "lava.predict<- function(X,Y, iter=5){\n",
    "    \n",
    "g1 = predict(rlasso(X, Y, post=F))  #lasso step fits \"sparse part\"\n",
    "m1 =  predict(glmnet(X, as.vector(Y-g1), family=\"gaussian\", alpha=0, lambda =20),newx=X ) #ridge step fits the \"dense\" part\n",
    "\n",
    "    \n",
    "i=1\n",
    "while(i<= iter) {\n",
    "g1 = predict(rlasso(X, Y, post=F))   #lasso step fits \"sparse part\"\n",
    "m1 = predict(glmnet(X, as.vector(Y-g1), family=\"gaussian\",  alpha=0, lambda =20),newx=X );  #ridge step fits the \"dense\" part\n",
    "i = i+1 }\n",
    "\n",
    "return(g1+m1);\n",
    "    }\n",
    "\n",
    "\n",
    "yhat.lava = lava.predict(X,Y)\n",
    "MSE.lava <- summary(lm((gX-yhat.lava)^2~1))$coef[1:2]       # report MSE and standard error for MSE for approximating g(X)\n",
    "\n",
    "    \n",
    "MSE.lava"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.072731,
     "end_time": "2021-02-15T11:04:15.033379",
     "exception": false,
     "start_time": "2021-02-15T11:04:14.960648",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "library(xtable)\n",
    "table<- matrix(0, 6, 2)\n",
    "table[1,1:2]   <- MSE.lasso.cv\n",
    "table[2,1:2]   <- MSE.ridge\n",
    "table[3,1:2]   <- MSE.elnet\n",
    "table[4,1:2]   <- MSE.lasso\n",
    "table[5,1:2]   <- MSE.lasso.post\n",
    "table[6,1:2]   <- MSE.lava\n",
    "\n",
    "colnames(table)<- c(\"MSA\", \"S.E. for MSA\")\n",
    "rownames(table)<- c(\"Cross-Validated Lasso\", \"Cross-Validated ridge\",\"Cross-Validated elnet\",\n",
    "                    \"Lasso\",\"Post-Lasso\",\"Lava\")\n",
    "tab <- xtable(table, digits =3)\n",
    "print(tab,type=\"latex\") # set type=\"latex\" for printing table in LaTeX\n",
    "tab\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.147036,
     "end_time": "2021-02-15T11:04:15.205679",
     "exception": false,
     "start_time": "2021-02-15T11:04:15.058643",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "plot(gX, gX, pch=19, cex=1, ylab=\"predicted value\", xlab=\"true g(X)\")\n",
    "\n",
    "points(gX, yhat.rlasso,   col=2, pch=18, cex = 1.5 )\n",
    "points(gX, yhat.elnet,  col=3, pch=17,  cex = 1.2  )\n",
    "points(gX, yhat.lava,  col=4, pch=19,  cex = 1.2 )\n",
    "\n",
    "\n",
    "legend(\"bottomright\", \n",
    "  legend = c(\"rLasso\", \"Elnet\", \"Lava\"), \n",
    "  col = c(2,3,4), \n",
    "  pch = c(18,17, 19), \n",
    "  bty = \"n\", \n",
    "  pt.cex = 1.3, \n",
    "  cex = 1.2, \n",
    "  text.col = \"black\", \n",
    "  horiz = F , \n",
    "  inset = c(0.1, 0.1))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.6.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 157.995397,
   "end_time": "2021-02-15T11:04:16.324442",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2021-02-15T11:01:38.329045",
   "version": "2.2.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
